4.2 Custom Middleware
This section explains how to create custom middleware that implements application-specific logic.
Overview
When the built-in middleware does not cover your needs, you can write your own. LangChain supports two ways to implement middleware:
- Decorator-based: simple, a single hook
- Class-based: more involved, multiple hooks
Hook types
Node-style hooks
Run sequentially at specific points in the execution:
| Hook | When it fires | Call count |
|---|---|---|
| before_agent | Before the agent starts | Once per invocation |
| before_model | Before each model call | Once per model call |
| after_model | After each model response | Once per model call |
| after_agent | After the agent finishes | Once per invocation |
Wrap-style hooks
Wrap a handler and nest like function calls, giving you control over how it is invoked:
| Hook | Wraps |
|---|---|
| wrap_model_call | Model calls |
| wrap_tool_call | Tool calls |
Decorator-based approach
The simplest way to create middleware, suited to a single hook:
before_model
Runs before every model call:
```python
from langchain.agents import create_agent, before_model, AgentState, Runtime

@before_model
def check_message_limit(state: AgentState, runtime: Runtime):
    """Enforce a message count limit."""
    if len(state["messages"]) >= 50:
        print("Too many messages, stopping execution")
        return {"jump_to": "end"}  # terminate early
    return None

agent = create_agent(
    "gpt-4o",
    tools=[my_tools],
    middleware=[check_message_limit]
)
```
after_model
Runs after every model response:
```python
from langchain.agents import create_agent, after_model, AgentState, Runtime

@after_model
def log_response(state: AgentState, response, runtime: Runtime):
    """Log the model response."""
    print(f"Model response: {response.content[:100]}...")
    return None

agent = create_agent(
    "gpt-4o",
    tools=[my_tools],
    middleware=[log_response]
)
```
before_agent
Runs before the agent starts (once per invocation):
```python
import time

from langchain.agents import before_agent, AgentState, Runtime

@before_agent
def initialize_session(state: AgentState, runtime: Runtime):
    """Initialize the session."""
    session_id = generate_session_id()  # application-defined helper
    print(f"Session started: {session_id}")
    # Return the update so it is merged into the agent state; session_id and
    # start_time must be declared in a custom state schema (see below).
    return {"session_id": session_id, "start_time": time.time()}
```
after_agent
Runs after the agent finishes:
```python
import time

from langchain.agents import after_agent, AgentState, Runtime

@after_agent
def cleanup_session(state: AgentState, runtime: Runtime):
    """Clean up the session."""
    duration = time.time() - state.get("start_time", 0)
    print(f"Session ended. Duration: {duration:.2f}s")
    return None
```
wrap_model_call
Wraps the model call:
```python
import time

from langchain.agents import wrap_model_call

@wrap_model_call
def time_model_call(state, runtime, call_next):
    """Time the model call."""
    start = time.time()
    result = call_next()  # perform the actual model call
    duration = time.time() - start
    print(f"Model call took {duration:.2f}s")
    return result
```
wrap_tool_call
Wraps tool calls:
```python
from langchain.agents import wrap_tool_call

@wrap_tool_call
def validate_tool_call(state, tool_name, tool_args, runtime, call_next):
    """Validate a tool call."""
    print(f"Calling tool: {tool_name}")
    print(f"Arguments: {tool_args}")
    # Validate or rewrite the arguments here
    if tool_name == "dangerous_tool":
        raise ValueError("This tool is not allowed")
    result = call_next()  # perform the actual tool call
    print(f"Tool result: {result}")
    return result
```
Class-based approach
Suited to more complex middleware that implements several hooks:
Basic structure
```python
from langchain.agents import create_agent, AgentMiddleware, AgentState, Runtime

class LoggingMiddleware(AgentMiddleware):
    """Logging middleware."""

    def __init__(self, log_level: str = "info"):
        self.log_level = log_level

    def before_agent(self, state: AgentState, runtime: Runtime):
        """When the agent starts."""
        print(f"[{self.log_level}] Agent starting...")
        return None

    def before_model(self, state: AgentState, runtime: Runtime):
        """Before each model call."""
        msg_count = len(state["messages"])
        print(f"[{self.log_level}] Model call with {msg_count} messages")
        return None

    def after_model(self, state: AgentState, response, runtime: Runtime):
        """After each model response."""
        print(f"[{self.log_level}] Model responded")
        return None

    def after_agent(self, state: AgentState, runtime: Runtime):
        """When the agent finishes."""
        print(f"[{self.log_level}] Agent completed")
        return None

# Usage
agent = create_agent(
    "gpt-4o",
    tools=[my_tools],
    middleware=[LoggingMiddleware(log_level="debug")]
)
```
Complex example: analytics middleware
```python
import time

from langchain.agents import AgentMiddleware

class AnalyticsMiddleware(AgentMiddleware):
    """Analytics-tracking middleware."""

    def __init__(self, analytics_client):
        self.client = analytics_client
        self.metrics = {
            "model_calls": 0,
            "tool_calls": 0,
            "tokens_used": 0,
            "start_time": None,
        }

    def before_agent(self, state, runtime):
        self.metrics["start_time"] = time.time()
        self.client.track("agent_started", {
            "thread_id": runtime.config.get("thread_id")
        })
        return None

    def after_model(self, state, response, runtime):
        self.metrics["model_calls"] += 1
        if hasattr(response, "response_metadata"):
            usage = response.response_metadata.get("usage", {})
            self.metrics["tokens_used"] += usage.get("total_tokens", 0)
        return None

    def wrap_tool_call(self, state, tool_name, tool_args, runtime, call_next):
        self.metrics["tool_calls"] += 1
        start = time.time()
        result = call_next()
        duration = time.time() - start
        self.client.track("tool_called", {
            "tool_name": tool_name,
            "duration": duration,
        })
        return result

    def after_agent(self, state, runtime):
        total_time = time.time() - self.metrics["start_time"]
        self.client.track("agent_completed", {
            "model_calls": self.metrics["model_calls"],
            "tool_calls": self.metrics["tool_calls"],
            "tokens_used": self.metrics["tokens_used"],
            "total_time": total_time,
        })
        return None
```
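AnalyticsMiddleware only assumes that analytics_client exposes a track(event, properties) method. A minimal sketch with a hypothetical in-memory client standing in for a real analytics backend:

```python
class InMemoryAnalyticsClient:
    """Hypothetical stand-in for a real analytics backend; it only
    implements the track(event, properties) interface used above."""

    def __init__(self):
        self.events = []

    def track(self, event: str, properties: dict):
        self.events.append((event, properties))
        print(f"[analytics] {event}: {properties}")

analytics = AnalyticsMiddleware(InMemoryAnalyticsClient())
agent = create_agent(
    "gpt-4o",
    tools=[my_tools],
    middleware=[analytics],
)
```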
Custom state schema
Extend the agent state to store custom data:
```python
from typing import Optional

from langchain.agents import AgentState, before_model, create_agent

class CustomAgentState(AgentState):
    """Custom agent state."""
    user_id: str
    session_id: Optional[str] = None
    request_count: int = 0
    custom_data: dict = {}

@before_model
def track_requests(state: CustomAgentState, runtime):
    """Track the number of requests."""
    # Return the update so it is merged back into the agent state
    return {"request_count": state.get("request_count", 0) + 1}

agent = create_agent(
    "gpt-4o",
    tools=[my_tools],
    state_schema=CustomAgentState,
    middleware=[track_requests]
)
```
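Custom fields such as user_id are supplied when the agent is invoked, alongside the initial messages. A minimal usage sketch, assuming the standard invoke payload (the user_id value is illustrative):

```python
result = agent.invoke({
    "messages": [{"role": "user", "content": "Hello"}],
    "user_id": "user-123",  # illustrative value
})
# The final state includes the custom keys updated by the middleware
print(result["request_count"])
```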
Agent jumps
Use jump_to to control the agent's flow:
```python
from langchain.agents import before_model
from langchain_core.messages import AIMessage

@before_model
def check_safety(state, runtime):
    """Safety check; terminate early when necessary."""
    last_message = state["messages"][-1].content
    # Detect unsafe content
    if contains_dangerous_content(last_message):
        # Append a warning message and jump to the end
        return {
            "messages": [AIMessage(content="Unsafe content detected, stopping execution.")],
            "jump_to": "end",
        }
    return None
```
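check_safety relies on a contains_dangerous_content helper that is not defined here; a minimal keyword-based placeholder (the blocklist is purely illustrative) could be:

```python
# Illustrative blocklist; replace with a real safety classifier in practice
BLOCKED_TERMS = ("drop table", "rm -rf", "ignore previous instructions")

def contains_dangerous_content(text: str) -> bool:
    """Hypothetical helper: flag text containing any blocked term."""
    lowered = text.lower()
    return any(term in lowered for term in BLOCKED_TERMS)
```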
Execution order
```python
# Middleware list
middleware = [middleware_1, middleware_2, middleware_3]

# Before hooks: run in order
# middleware_1.before_model() → middleware_2.before_model() → middleware_3.before_model()

# After hooks: run in reverse order
# middleware_3.after_model() → middleware_2.after_model() → middleware_1.after_model()

# Wrap hooks: nest around the call
# middleware_1.wrap_model_call(
#   middleware_2.wrap_model_call(
#     middleware_3.wrap_model_call(
#       actual_model_call()
#     )
#   )
# )
```
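To observe this ordering, you can register a few trivial middlewares that only print their names, following the class-based pattern shown earlier (the Tagged class is purely illustrative):

```python
class Tagged(AgentMiddleware):
    """Illustrative middleware that prints its tag in each node hook."""

    def __init__(self, tag: str):
        self.tag = tag

    def before_model(self, state, runtime):
        print(f"{self.tag}: before_model")
        return None

    def after_model(self, state, response, runtime):
        print(f"{self.tag}: after_model")
        return None

agent = create_agent(
    "gpt-4o",
    tools=[my_tools],
    middleware=[Tagged("first"), Tagged("second"), Tagged("third")],
)
# Expected: before_model prints first/second/third, after_model prints third/second/first
```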
Practical examples
Rate-limiting middleware
```python
import time
from collections import deque

from langchain.agents import AgentMiddleware

class RateLimitMiddleware(AgentMiddleware):
    """Rate-limiting middleware."""

    def __init__(self, max_requests: int = 10, window_seconds: int = 60):
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self.request_times = deque()

    def before_model(self, state, runtime):
        current_time = time.time()
        # Drop request records that have fallen outside the window
        while self.request_times and \
                self.request_times[0] < current_time - self.window_seconds:
            self.request_times.popleft()
        # Check whether the limit has been reached
        if len(self.request_times) >= self.max_requests:
            wait_time = self.request_times[0] + self.window_seconds - current_time
            print(f"Rate limit exceeded. Waiting {wait_time:.1f}s...")
            time.sleep(wait_time)
        # Record the time the request actually proceeds
        self.request_times.append(time.time())
        return None
```
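Wiring the limiter into an agent works like any other class-based middleware; the limits below are illustrative. Note that time.sleep blocks the calling thread, which is the simplicity/throughput trade-off behind the "avoid blocking" practice listed later.

```python
rate_limiter = RateLimitMiddleware(max_requests=5, window_seconds=30)  # illustrative limits

agent = create_agent(
    "gpt-4o",
    tools=[my_tools],
    middleware=[rate_limiter],
)
```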
Cost-tracking middleware
```python
from langchain.agents import AgentMiddleware

class CostTrackingMiddleware(AgentMiddleware):
    """Cost-tracking middleware."""

    # Price per 1K tokens (USD)
    PRICING = {
        "gpt-4o": {"input": 0.005, "output": 0.015},
        "gpt-3.5-turbo": {"input": 0.0005, "output": 0.0015},
    }

    def __init__(self):
        self.total_cost = 0.0

    def after_model(self, state, response, runtime):
        if hasattr(response, "response_metadata"):
            usage = response.response_metadata.get("usage", {})
            model = response.response_metadata.get("model", "gpt-4o")
            pricing = self.PRICING.get(model, self.PRICING["gpt-4o"])
            input_cost = (usage.get("prompt_tokens", 0) / 1000) * pricing["input"]
            output_cost = (usage.get("completion_tokens", 0) / 1000) * pricing["output"]
            self.total_cost += input_cost + output_cost
            print(f"Current cost: ${self.total_cost:.4f}")
        return None
```
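Because the middleware instance holds the running total, keep a reference to it and read total_cost after a run. A minimal sketch, assuming the standard messages invoke payload:

```python
cost_tracker = CostTrackingMiddleware()

agent = create_agent(
    "gpt-4o",
    tools=[my_tools],
    middleware=[cost_tracker],
)

agent.invoke({"messages": [{"role": "user", "content": "Summarize this document"}]})
print(f"Total cost so far: ${cost_tracker.total_cost:.4f}")
```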
Caching middleware
```python
import hashlib
import json

from langchain.agents import AgentMiddleware
from langchain_core.messages import AIMessage

class CacheMiddleware(AgentMiddleware):
    """Response-caching middleware."""

    def __init__(self, cache_client):
        self.cache = cache_client

    def _get_cache_key(self, state):
        """Build a cache key from the conversation so far."""
        messages_str = json.dumps([
            {"role": m.type, "content": m.content}
            for m in state["messages"]
        ])
        return hashlib.md5(messages_str.encode()).hexdigest()

    def before_model(self, state, runtime):
        """Check the cache."""
        cache_key = self._get_cache_key(state)
        cached = self.cache.get(cache_key)
        if cached:
            print("Cache hit!")
            # Return the cached response and skip the model call
            return {
                "messages": [AIMessage(content=cached)],
                "jump_to": "end",
            }
        return None

    def after_model(self, state, response, runtime):
        """Store the response in the cache."""
        if response.content:
            cache_key = self._get_cache_key(state)
            self.cache.set(cache_key, response.content, ttl=3600)
        return None
```
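CacheMiddleware only assumes that cache_client exposes get(key) and set(key, value, ttl=...). A hypothetical in-memory client, standing in for Redis or similar, might look like:

```python
import time

class InMemoryCache:
    """Hypothetical cache client implementing the get/set interface used above."""

    def __init__(self):
        self._store = {}

    def get(self, key):
        value, expires_at = self._store.get(key, (None, 0))
        return value if time.time() < expires_at else None

    def set(self, key, value, ttl=3600):
        self._store[key] = (value, time.time() + ttl)

agent = create_agent(
    "gpt-4o",
    tools=[my_tools],
    middleware=[CacheMiddleware(InMemoryCache())],
)
```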
Best practices
| Practice | Description |
|---|---|
| Stay focused | Each middleware should do one thing |
| Degrade gracefully | A middleware error should not crash the agent (see the sketch below) |
| Avoid blocking | Do not run slow operations inside hooks |
| Log what happens | Makes debugging and monitoring easier |
| Test thoroughly | Test each middleware in isolation |
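One way to degrade gracefully is to wrap the hook body in try/except so that a failure inside the middleware is logged rather than raised; a minimal sketch in the decorator style shown above (the logger name is illustrative):

```python
import logging

from langchain.agents import before_model

logger = logging.getLogger("middleware")

@before_model
def safe_annotation(state, runtime):
    """Best-effort hook: any failure is logged, never raised."""
    try:
        # Whatever optional work the middleware does goes here
        logger.info("Messages so far: %d", len(state["messages"]))
    except Exception:
        logger.exception("safe_annotation failed; continuing without it")
    return None
```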
Debugging tips
```python
from langchain.agents import before_model

@before_model
def debug_state(state, runtime):
    """Debugging: print the current state."""
    print("=" * 50)
    print(f"Messages: {len(state['messages'])}")
    print(f"Last message: {state['messages'][-1].content[:100]}")
    print(f"Config: {runtime.config}")
    print("=" * 50)
    return None
```